
Fix support for optional inputs in model.fit #21548


Open
wants to merge 8 commits into base: master

Conversation

neo-alex
Contributor

@neo-alex neo-alex commented Aug 4, 2025

Here is an example of a model with 2 inputs, the second one being optional:

import numpy as np
from keras import Input, Model, layers, losses

class OptionalInputLayer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(2)

    def call(self, x, y=None):
        # Only use y when it is provided.
        z = x if y is None else x + y
        return self.dense(z)

    def compute_output_shape(self, x_shape):
        return x_shape

i1 = Input((2,), name="input1")
i2 = Input((2,), name="input2", optional=True)
outputs = OptionalInputLayer()(i1, i2)
model = Model({"input1": i1, "input2": i2}, outputs)

With this definition, the model can be called in Jax/TF/Torch without issue, even when input2 is None:

model({"input1": np.ones((2, 2)), "input2": None})  # WORKS
model.predict_on_batch({"input1": np.ones((2, 2)), "input2": None})  # WORKS AS WELL

It is even possible to train on a batch when input2 is None:

model.compile(loss=losses.MeanSquaredError())
model.train_on_batch(x={"input1": np.ones((2, 2)), "input2": None}, y=np.ones((2, 2)))  # WORKS

However, doing the same with model.fit API is currently failing on all backends:

# Without generator
model.fit(x={"input1": np.ones((2, 2)), "input2": None}, y=np.ones((2, 2)))  # DOESN'T WORK

# With generator
data_generator = (({"input1": np.ones((2, 2)), "input2": None}, np.ones((2, 2))) for _ in range(3))
model.fit(x=data_generator)  # DOESN'T WORK EITHER

The purpose of this PR is to fix this issue (on JAX/TF/Torch), so that the last code block above becomes possible. (By the way, I could add 1 or 2 unit tests along those lines to demonstrate the fix, but I am unsure where the best place for them would be... model_test maybe?)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @neo-alex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where model.fit in Keras fails to properly handle optional inputs provided as None, unlike model.predict_on_batch and model.train_on_batch. My changes ensure that model.fit can correctly process models with optional None inputs across TensorFlow, JAX, and PyTorch backends, aligning its behavior with other training utilities.

Highlights

  • Enabled model.fit with Optional None Inputs: The primary goal of this PR is to allow model.fit to successfully train models where certain inputs are optionally provided as None, which was previously causing failures.
  • Enhanced tree.map_structure Handling of None: I've updated the tree.map_structure utility across Keras to include a none_is_leaf parameter. Setting this to False ensures that None values are traversed as part of the structure rather than being treated as terminal leaves, which is crucial for handling optional inputs.
  • TensorFlow Backend Adaptations: For the TensorFlow backend, I've introduced explicit conversions of tf.experimental.Optional instances to None within the train_step and ensured that None values from generators are correctly converted to tf.experimental.Optional.empty(None) when constructing tf.data.Dataset objects.
  • Improved Data Adapter Robustness: Various data adapters (e.g., ArrayDataAdapter, GeneratorDataAdapter, TFDatasetAdapter, TorchDataLoaderAdapter, GrainDatasetAdapter) have been modified to correctly handle None values when calculating sample cardinalities, creating tensor specifications, or iterating through batches, preventing errors related to missing input shapes.
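The none_is_leaf semantics described in the highlights can be sketched in plain Python. The helper below is a hypothetical, simplified stand-in for Keras' tree.map_structure, not the actual implementation; it only covers dicts, lists, and tuples:

```python
def map_structure(func, structure, none_is_leaf=True):
    # Simplified stand-in for keras tree.map_structure (hypothetical).
    # none_is_leaf=True: func is called on None leaves like any other leaf.
    # none_is_leaf=False: None values bypass func and are kept as-is, so an
    # optional input that is None survives traversal untouched.
    if isinstance(structure, dict):
        return {
            k: map_structure(func, v, none_is_leaf)
            for k, v in structure.items()
        }
    if isinstance(structure, (list, tuple)):
        return type(structure)(
            map_structure(func, v, none_is_leaf) for v in structure
        )
    if structure is None and not none_is_leaf:
        return None
    return func(structure)

batch = {"input1": [1.0, 2.0], "input2": None}
print(map_structure(lambda x: x * 2, batch, none_is_leaf=False))
# {'input1': [2.0, 4.0], 'input2': None}
```

With none_is_leaf=True (the default in the PR), the lambda would instead receive the None leaf and would have to handle it itself.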

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for optional inputs in model.fit by introducing a none_is_leaf parameter to tree.map_structure. This allows None values, which represent optional inputs, to be correctly handled across various data adapters and backends. The changes are logical and consistently applied. However, I've found a potential issue where the logic to handle TensorFlow's Optional type is missing from test_step and predict_step, which could cause problems during evaluation and prediction.

@codecov-commenter

codecov-commenter commented Aug 4, 2025

Codecov Report

❌ Patch coverage is 63.15789% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.71%. Comparing base (7bf852c) to head (c5b636a).
⚠️ Report is 10 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/tree/dmtree_impl.py 33.33% 7 Missing and 1 partial ⚠️
...s/src/trainers/data_adapters/data_adapter_utils.py 71.42% 1 Missing and 1 partial ⚠️
...c/trainers/data_adapters/generator_data_adapter.py 0.00% 2 Missing ⚠️
...rc/trainers/data_adapters/grain_dataset_adapter.py 50.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21548      +/-   ##
==========================================
- Coverage   82.72%   82.71%   -0.01%     
==========================================
  Files         567      567              
  Lines       56264    56476     +212     
  Branches     8797     8829      +32     
==========================================
+ Hits        46544    46716     +172     
- Misses       7562     7593      +31     
- Partials     2158     2167       +9     
Flag Coverage Δ
keras 82.52% <63.15%> (-0.01%) ⬇️
keras-jax 63.79% <47.36%> (-0.14%) ⬇️
keras-numpy 58.28% <42.10%> (-0.13%) ⬇️
keras-openvino 34.63% <21.05%> (+0.06%) ⬆️
keras-tensorflow 64.23% <52.63%> (-0.12%) ⬇️
keras-torch 63.85% <47.36%> (-0.14%) ⬇️

Collaborator

@hertschuh hertschuh left a comment


Thank you for the PR!

I am unsure where would be the best place for it... model_test maybe?).

Yes, this would be the place to test. Ideally, it would test fit, predict and evaluate.

Also, ideally, this would be tested in *_data_adapter_test.py to cover all cases.

Taking a step back, is the goal to handle the case where, in the dataset passed to fit, "input2" is always None? Or sometimes None and sometimes not None? Right now it looks like it only supports the former (always None).

@@ -32,6 +32,8 @@ def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf

def convert_to_tf(x, spec):
if isinstance(spec, tf.OptionalSpec):
return x
Collaborator

@hertschuh hertschuh Aug 6, 2025


Can't you just return tf.experimental.Optional.empty(None) here and remove lines 55-62?

Or tf.experimental.Optional.empty(None) if x is None else x?

Either way, lines 55-62 should move here.

Contributor Author


Unfortunately no (this is what I tried first): an error is then raised by tree.map_structure on line 63, because batch and self._output_signature do not have the same structure (more specifically, None leaves in batch do not match the corresponding leaves in self._output_signature, which are tf.OptionalSpec instead). This is why I had to convert the None leaves in batch first, in lines 55-62. Please let me know if you find a more elegant solution, though (I am also not a fan of having 2 map_structure calls in a row if it is avoidable).

Contributor Author


Erratum: sorry, I got confused in my own tests (there is actually no issue using tree.map_structure with None leaves on one structure and tf.OptionalSpec on another, as long as none_is_leaf=True - which is the default). So you are right and I simplified the logic according to your comment in this commit.

@@ -179,6 +179,7 @@ def map_structure(func, *structures):
Args:
func: A callable that accepts as many arguments as there are structures.
*structures: Arbitrarily nested structures of the same layout.
none_is_leaf: If True, None is treated as a leaf.
Collaborator


Can you add more details here? The name none_is_leaf is pretty unintuitive actually. Basically, say something like:

none_is_leaf=True causes func to be called on None leaves, and none_is_leaf=False means Nones are not passed to func and are returned in the output directly.

Contributor Author


Sure, I improved its docstring accordingly in this commit. By the way, I agree that the name none_is_leaf is not the most intuitive, but it is used consistently throughout the underlying optree library (e.g. here), so I kept the same one.

if not all(s is None for s in args):
raise ValueError(
"Structure mismatch: some arguments are None, others "
"are not."
Collaborator


Issues while running map_structure can be hard to debug. Any bit of context can help.

Can you add args?

                    raise ValueError(
                        "Structure mismatch: some arguments are None, others "
                        f"are not: {args}."
                    )

Contributor Author


OK, done in this commit.

def _convert_optional_to_none(self, x):
# Convert TF Optional implementations to None
return tree.map_structure(
lambda i: None if isinstance(i, tf.experimental.Optional) else i, x
Collaborator


Don't you also need to do i.get_value() if i.has_value() else None? So that you support both the None and not None cases?
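The reviewer's point can be illustrated with a minimal stand-in class. The FakeOptional class and optional_to_value helper below are hypothetical names for illustration only, under the assumption that they mirror the has_value()/get_value() interface of tf.experimental.Optional:

```python
class FakeOptional:
    """Minimal stand-in mirroring tf.experimental.Optional's interface."""

    def __init__(self, value=None):
        self._value = value

    def has_value(self):
        return self._value is not None

    def get_value(self):
        return self._value


def optional_to_value(i):
    # Unwrapping must cover both the empty case (-> None) and the non-empty
    # case (-> the wrapped value), so that an optional input that is
    # "sometimes None, sometimes a tensor" round-trips correctly.
    if isinstance(i, FakeOptional):
        return i.get_value() if i.has_value() else None
    return i


print(optional_to_value(FakeOptional()))    # None
print(optional_to_value(FakeOptional(42)))  # 42
print(optional_to_value(7))                 # 7 (non-Optional passes through)
```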

Contributor Author


Yes you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing None and not None cases).

@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
"""
from keras.src.utils.module_utils import tensorflow as tf

if keras_tensor is None:
Collaborator


Does this actually ever happen?

My assumption was that this would need to handle non-None inputs that have optional=True on them (this might require some changes), and then create a tf.OptionalSpec(<the actual tensorspec for the KerasTensor per the code below>).

Contributor Author


It does actually happen, even if the reason is not intuitive. Your assumption makes a lot of sense (ideally, optional inputs would be represented by a KerasTensor with optional=True, like in the model); unfortunately, all the code in data_adapters is independent from the model, and the data spec is inferred solely from the first batches of received data (typically here)... which indeed seems a bit brittle and prone to some "hidden" constraints on the first batches of the dataset (e.g. see this error message).

Since it is not possible to infer a proper KerasTensor just from a received None value, the trick I am using is to keep it as None (by using the newly introduced none_is_leaf=False inside get_keras_tensor_spec), which explains then that the line of code you mention is actually needed.
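That inference-from-data behavior can be sketched as follows. This is a hypothetical simplification (infer_spec and the shape-only stand-ins are illustrative, not the actual adapter code): the spec is built from the first batch alone, with the batch axis relaxed to None, and None optional inputs are kept as None placeholders, mirroring the none_is_leaf=False behavior:

```python
from types import SimpleNamespace

def infer_spec(batch):
    # Hypothetical sketch of spec inference from a first batch: shapes come
    # from the data alone, so a None optional input carries no shape
    # information and must simply stay None in the inferred spec.
    def leaf_spec(x):
        # Relax the batch axis to None, as the real adapters do.
        return ("TensorSpec", (None,) + tuple(x.shape[1:]))
    return {
        k: None if v is None else leaf_spec(v)
        for k, v in batch.items()
    }

first_batch = {
    "input1": SimpleNamespace(shape=(2, 2)),  # stand-in for an array
    "input2": None,                           # optional input absent
}
print(infer_spec(first_batch))
# {'input1': ('TensorSpec', (None, 2)), 'input2': None}
```

This also makes the limitation visible: if "input2" is None in the first batch but a tensor later, the inferred spec no longer matches the data.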

@neo-alex
Contributor Author

neo-alex commented Aug 7, 2025

Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never None. Of course, the best would be to enable both (sometimes None, sometimes a tensor) but this can be a bit tricky since tensor specs in data adapters are actually inferred from the first batches (at least for some backends like Tensorflow) - see my last comment above for more technical details.

I will continue to investigate and see if a solution for mixed values (None & not None) is still possible, given some constraints like "in the first 2 batches, every optional input should include a None value and a tensor one" (similar constraints are already assumed in current data adapters, as seen here). I will come back to you shortly about this.

@hertschuh
Collaborator

> Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never None. Of course, the best would be to enable both (sometimes None, sometimes a tensor) but this can be a bit tricky since tensor specs in data adapters are actually inferred from the first batches (at least for some backends like Tensorflow) - see my last comment above for more technical details.
>
> I will continue to investigate and see if a solution for mixed values (None & not None) is still possible, given some constraints like "in the first 2 batches, every optional input should include a None value and a tensor one" (similar constraints are already assumed in current data adapters, as seen here). I will come back to you shortly about this.

You're right, the data spec and the model inputs are disconnected, which, as you point out, is the source of a number of shortcomings. It might not be possible to mix None and not None without this connection, which will be complex to add. Unless we do the same hack as for the dynamic dimensions that you linked.
